{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "3a10f8f1",
   "metadata": {},
   "source": [
    "##  review"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1f82b0ed",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-07-08T13:28:17.611526Z",
     "start_time": "2023-07-08T13:28:17.588817Z"
    }
   },
   "source": [
    "- $q(x)$：from student model，$p(x)$：from teacher model\n",
    "- 其次对于 $q(x), p(x)$ 在计算时需要加温度\n",
    "$$\n",
    "\\begin{split}\n",
    "L_{\\text{student}}&=\\alpha L_{\\text{CE}} + (1-\\alpha)L_{KD}\\\\\n",
    "&=\\alpha L_{\\text{CE}} + (1-\\alpha)T^2D_{KL}\\\\\n",
    "&=\\alpha L_{\\text{CE}} + (1-\\alpha)T^2\\sum_ip_i(x)\\log\\frac{p_i(x)}{q_i(x)}\n",
    "\\end{split}\n",
    "$$\n",
    "\n",
    "- 关于 `nn.KLDivLoss()`\n",
    "    - inputs ($q(x)$): log probabilities\n",
    "    - labels ($p(x)$): normal probabilities"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ede970ba",
   "metadata": {},
   "source": [
    "## trainer arguments & trainer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "8bf76fa8",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-07-08T13:42:34.579255Z",
     "start_time": "2023-07-08T13:42:34.573911Z"
    }
   },
   "outputs": [],
   "source": [
    "from transformers import TrainingArguments, Trainer\n",
    "import torch\n",
    "from torch import nn\n",
    "import torch.nn.functional as F\n",
    "import numpy as np\n",
    "import wandb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "59dde685",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-07-08T13:42:36.134423Z",
     "start_time": "2023-07-08T13:42:36.126195Z"
    }
   },
   "outputs": [],
   "source": [
    "class DistillTrainingArguments(TrainingArguments):\n",
    "    # TrainingArguments: @dataclass\n",
    "    # 增加两个 KD 所需的参数参数\n",
    "    def __init__(self, *args, alpha=0.5, temperature=2., **kwargs):\n",
    "        super().__init__(*args, **kwargs)\n",
    "        self.alpha = alpha\n",
    "        self.temperature = temperature"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "35763030",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-07-08T13:45:33.019991Z",
     "start_time": "2023-07-08T13:45:33.005716Z"
    }
   },
   "outputs": [],
   "source": [
    "class DistillTrainer(Trainer):\n",
    "    \n",
    "    def __init__(self, *args, teacher_model=None, **kwargs):\n",
    "        # 增加 teacher_model 参数\n",
    "        super().__init__(*args, **kwargs)\n",
    "        self.teacher_model = teacher_model\n",
    "        \n",
    "    # 重写 trainer 中核心方法\n",
    "    # forward 计算损失\n",
    "    def compute_loss(self, model, inputs, return_outputs=False):\n",
    "        s_output = model(**inputs)\n",
    "        s_ce = s_output.loss\n",
    "        s_logits = s_output.logits\n",
    "        \n",
    "        with torch.no_grad():\n",
    "            t_output = self.teacher_model(**inputs)\n",
    "            t_logits = t_output.logits\n",
    "        \n",
    "        loss_kl_fct = nn.KLDivLoss(reduction='batchmean')\n",
    "        loss_kd = self.args.temperature**2 * loss_kl_fct(F.log_softmax(s_logits/self.args.temperature, dim=-1), \n",
    "                                                        F.softmax(t_logits/self.args.temperature, dim=-1))\n",
    "        loss = self.args.alpha * s_ce + (1-self.args.alpha) * loss_kd\n",
    "        return (loss, s_output) if return_outputs else loss\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "07bd85c3",
   "metadata": {},
   "source": [
    "## pipeline"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bf2702dd",
   "metadata": {},
   "source": [
    "### datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "79f683fc",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-07-08T13:45:41.732252Z",
     "start_time": "2023-07-08T13:45:41.725945Z"
    }
   },
   "outputs": [],
   "source": [
    "# import os\n",
    "# os.environ['HTTP_PROXY'] = 'http://127.0.0.1:7890'\n",
    "# os.environ['HTTPS_PROXY'] = 'http://127.0.0.1:7890'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "5b79b20b",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-07-08T13:45:54.620037Z",
     "start_time": "2023-07-08T13:45:42.507852Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Found cached dataset clinc_oos (/media/whaow/.cache/huggingface/datasets/clinc_oos/plus/1.0.0/abcc41d382f8137f039adc747af44714941e8196e845dfbdd8ae7a7e020e6ba1)\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "727be4daa19b4ad58579fec9fccf439b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/3 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from datasets import load_dataset\n",
    "clinc = load_dataset(\"clinc_oos\", \"plus\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "e1c17cac",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-07-08T13:45:57.176014Z",
     "start_time": "2023-07-08T13:45:57.167301Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    train: Dataset({\n",
       "        features: ['text', 'intent'],\n",
       "        num_rows: 15250\n",
       "    })\n",
       "    validation: Dataset({\n",
       "        features: ['text', 'intent'],\n",
       "        num_rows: 3100\n",
       "    })\n",
       "    test: Dataset({\n",
       "        features: ['text', 'intent'],\n",
       "        num_rows: 5500\n",
       "    })\n",
       "})"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "clinc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "9f258185",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-07-08T13:46:17.612259Z",
     "start_time": "2023-07-08T13:46:17.599359Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'text': ['what expression would i use to say i love you if i were an italian',\n",
       "  \"can you tell me how to say 'i do not speak much spanish', in spanish\",\n",
       "  \"what is the equivalent of, 'life is good' in french\",\n",
       "  \"tell me how to say, 'it is a beautiful morning' in italian\",\n",
       "  'if i were mongolian, how would i say that i am a tourist',\n",
       "  \"how do i say 'hotel' in finnish\",\n",
       "  \"i need you to translate the sentence, 'we will be there soon' into portuguese\",\n",
       "  'please tell me how to ask for a taxi in french',\n",
       "  \"can you tell me how i would say, 'more bread please' in french\",\n",
       "  \"what is the correct way to say 'i am a visitor' in french\"],\n",
       " 'intent': [61, 61, 61, 61, 61, 61, 61, 61, 61, 61]}"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "clinc['train'][:10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "4366a862",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-07-08T13:46:29.621894Z",
     "start_time": "2023-07-08T13:46:29.611922Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "151"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "intents = clinc['train'].features['intent']\n",
    "num_labels = intents.num_classes\n",
    "num_labels"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c6a4a0a9",
   "metadata": {},
   "source": [
    "### Student model 初始化"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "313668df",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-07-08T13:46:48.960423Z",
     "start_time": "2023-07-08T13:46:48.949623Z"
    }
   },
   "outputs": [],
   "source": [
    "from transformers import AutoConfig, AutoTokenizer\n",
    "from transformers import AutoModelForSequenceClassification"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "a68a82b7",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-07-08T13:47:49.353462Z",
     "start_time": "2023-07-08T13:47:00.143015Z"
    }
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b4c3fc332a3e4904820bde274d27ddd8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading (…)okenizer_config.json:   0%|          | 0.00/28.0 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9d0a9e1f4b964726b1454f5938404c65",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading (…)lve/main/config.json:   0%|          | 0.00/483 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "246915b110a74cd0b1655682d24522cb",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8df2f72cac584fba95f216bcc783011b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading (…)/main/tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6c8568dcf0004557b64a93bedf56c71c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading (…)lve/main/config.json:   0%|          | 0.00/8.18k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "010e2972de19439b8e2e77d9152e7afe",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading pytorch_model.bin:   0%|          | 0.00/438M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "s_ckpt = 'distilbert-base-uncased'\n",
    "s_tokenizer = AutoTokenizer.from_pretrained(s_ckpt)\n",
    "\n",
    "t_ckpt = 'transformersbook/bert-base-uncased-finetuned-clinc'\n",
    "t_model = AutoModelForSequenceClassification.from_pretrained(t_ckpt, num_labels=num_labels).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "30765bdf",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-07-08T13:48:33.125442Z",
     "start_time": "2023-07-08T13:48:32.121377Z"
    }
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Map:   0%|          | 0/15250 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Map:   0%|          | 0/3100 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Map:   0%|          | 0/5500 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    train: Dataset({\n",
       "        features: ['labels', 'input_ids', 'attention_mask'],\n",
       "        num_rows: 15250\n",
       "    })\n",
       "    validation: Dataset({\n",
       "        features: ['labels', 'input_ids', 'attention_mask'],\n",
       "        num_rows: 3100\n",
       "    })\n",
       "    test: Dataset({\n",
       "        features: ['labels', 'input_ids', 'attention_mask'],\n",
       "        num_rows: 5500\n",
       "    })\n",
       "})"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "clinc_enc = clinc.map(lambda batch: s_tokenizer(batch['text'], truncation=True), \n",
    "                      batched=True, \n",
    "                      remove_columns=[\"text\"]\n",
    "                     )\n",
    "clinc_enc = clinc_enc.rename_columns({'intent': 'labels'})\n",
    "clinc_enc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "c9a08ae2",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-07-08T13:50:12.756586Z",
     "start_time": "2023-07-08T13:50:12.379492Z"
    }
   },
   "outputs": [],
   "source": [
    "batch_size = 64\n",
    "s_training_args = DistillTrainingArguments(output_dir='distilbert-base-uncased-ft-clinc', \n",
    "                                           evaluation_strategy='epoch', num_train_epochs=5, \n",
    "                                           learning_rate=3e-4, \n",
    "                                           per_device_train_batch_size=batch_size, \n",
    "                                           per_device_eval_batch_size=batch_size, \n",
    "                                           alpha=0.5, weight_decay=0.01, \n",
    "                                           logging_strategy='epoch',\n",
    "                                           push_to_hub=False)\n",
    "s_config = AutoConfig.from_pretrained(s_ckpt, num_labels=num_labels, \n",
    "                                      id2label=t_model.config.id2label, label2id=t_model.config.label2id)\n",
    "# s_config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "86ba1a09",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-07-08T13:50:32.717059Z",
     "start_time": "2023-07-08T13:50:32.709207Z"
    }
   },
   "outputs": [],
   "source": [
    "\n",
    "def student_init():\n",
    "    return AutoModelForSequenceClassification.from_pretrained(s_ckpt, config=s_config).to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0294067e",
   "metadata": {},
   "source": [
    "### trainer.train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "8026a1ec",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-07-08T13:50:45.557236Z",
     "start_time": "2023-07-08T13:50:41.094296Z"
    }
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ef4d0fd0e7cf4f89ba1c33915f31124f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading builder script:   0%|          | 0.00/4.20k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# from datasets import load_metric\n",
    "# accuracy_score = load_metric('accuracy')\n",
    "# SequenceClassification\n",
    "import evaluate\n",
    "accuracy_score = evaluate.load('accuracy')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "6d9f99a2",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-07-08T13:51:05.438475Z",
     "start_time": "2023-07-08T13:51:05.430725Z"
    }
   },
   "outputs": [],
   "source": [
    "# trainer 重要的回调函数，非成员函数\n",
    "def compute_metrics(pred):\n",
    "    predictions, labels = pred\n",
    "    predictions = np.argmax(predictions, axis=-1)\n",
    "    return accuracy_score.compute(references=labels, predictions=predictions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "c635f0b7",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-07-08T13:54:04.344304Z",
     "start_time": "2023-07-08T13:51:39.054134Z"
    },
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1286df8eea81463c920fed9191402439",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading model.safetensors:   0%|          | 0.00/268M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_layer_norm.weight', 'vocab_layer_norm.bias', 'vocab_projector.bias', 'vocab_transform.bias', 'vocab_transform.weight']\n",
      "- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
      "Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['pre_classifier.weight', 'classifier.weight', 'pre_classifier.bias', 'classifier.bias']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
      "Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_layer_norm.weight', 'vocab_layer_norm.bias', 'vocab_projector.bias', 'vocab_transform.bias', 'vocab_transform.weight']\n",
      "- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
      "Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['pre_classifier.weight', 'classifier.weight', 'pre_classifier.bias', 'classifier.bias']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
      "/home/whaow/anaconda3/lib/python3.10/site-packages/transformers/optimization.py:391: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
      "  warnings.warn(\n",
      "Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.\n",
      "\u001b[34m\u001b[1mwandb\u001b[0m: W&B API key is configured. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9399750910eb4585953c636d584a755a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "VBox(children=(Label(value='Waiting for wandb.init()...\\r'), FloatProgress(value=0.016668651233332336, max=1.0…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "wandb version 0.15.5 is available!  To upgrade, please run:\n",
       " $ pip install wandb --upgrade"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Tracking run with wandb version 0.15.0"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Run data is saved locally in <code>/home/whaow/workspaces/bert_t5_gpt/tutorials/wandb/run-20230708_215253-e5di60uj</code>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Syncing run <strong><a href='https://wandb.ai/loveresearch/huggingface/runs/e5di60uj' target=\"_blank\">vague-sky-43</a></strong> to <a href='https://wandb.ai/loveresearch/huggingface' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       " View project at <a href='https://wandb.ai/loveresearch/huggingface' target=\"_blank\">https://wandb.ai/loveresearch/huggingface</a>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       " View run at <a href='https://wandb.ai/loveresearch/huggingface/runs/e5di60uj' target=\"_blank\">https://wandb.ai/loveresearch/huggingface/runs/e5di60uj</a>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "You're using a DistilBertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n",
      "/home/whaow/anaconda3/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      "  warnings.warn('Was asked to gather along dimension 0, but all '\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='600' max='600' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [600/600 00:57, Epoch 5/5]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Epoch</th>\n",
       "      <th>Training Loss</th>\n",
       "      <th>Validation Loss</th>\n",
       "      <th>Accuracy</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>1.206500</td>\n",
       "      <td>0.401095</td>\n",
       "      <td>0.912903</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>0.280000</td>\n",
       "      <td>0.318143</td>\n",
       "      <td>0.943871</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>0.209600</td>\n",
       "      <td>0.278893</td>\n",
       "      <td>0.953226</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>0.190100</td>\n",
       "      <td>0.268503</td>\n",
       "      <td>0.956452</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5</td>\n",
       "      <td>0.183800</td>\n",
       "      <td>0.266352</td>\n",
       "      <td>0.957742</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/whaow/anaconda3/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      "  warnings.warn('Was asked to gather along dimension 0, but all '\n",
      "/home/whaow/anaconda3/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      "  warnings.warn('Was asked to gather along dimension 0, but all '\n",
      "/home/whaow/anaconda3/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      "  warnings.warn('Was asked to gather along dimension 0, but all '\n",
      "/home/whaow/anaconda3/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      "  warnings.warn('Was asked to gather along dimension 0, but all '\n",
      "/home/whaow/anaconda3/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      "  warnings.warn('Was asked to gather along dimension 0, but all '\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "TrainOutput(global_step=600, training_loss=0.41400583267211916, metrics={'train_runtime': 76.8404, 'train_samples_per_second': 992.317, 'train_steps_per_second': 7.808, 'total_flos': 456233053284036.0, 'train_loss': 0.41400583267211916, 'epoch': 5.0})"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "distill_trainer = DistillTrainer(model_init=student_init, teacher_model=t_model, args=s_training_args, \n",
    "                                 train_dataset=clinc_enc['train'], eval_dataset=clinc_enc['validation'], \n",
    "                                 compute_metrics=compute_metrics, tokenizer=s_tokenizer)\n",
    "distill_trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "820b9f48",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-07-08T13:54:24.687818Z",
     "start_time": "2023-07-08T13:54:24.678385Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "600"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import math\n",
    "math.ceil(15250/(64*2)) * 5"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "740c3a00",
   "metadata": {},
   "source": [
    "### 使用"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "69a441ae",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-07-05T13:51:23.505698Z",
     "start_time": "2023-07-05T13:51:23.498648Z"
    }
   },
   "outputs": [],
   "source": [
    "# import os\n",
    "# os.environ['HTTP_PROXY'] = 'http://127.0.0.1:7890'\n",
    "# os.environ['HTTPS_PROXY'] = 'http://127.0.0.1:7890'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "b6a527d3",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-07-08T13:55:04.861128Z",
     "start_time": "2023-07-08T13:55:03.957835Z"
    }
   },
   "outputs": [],
   "source": [
    "from transformers import pipeline\n",
    "\n",
    "# ft_ckpt = 'lanchunhui/distilbert-base-uncased-ft-clinc'\n",
    "# distill_trainer.push_to_hub('finetune completed!')\n",
    "\n",
    "pipe = pipeline('text-classification', model='./distilbert-base-uncased-ft-clinc/')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4dd8705e",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {
    "height": "calc(100% - 180px)",
    "left": "10px",
    "top": "150px",
    "width": "250px"
   },
   "toc_section_display": true,
   "toc_window_display": true
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
